In this assignment, I would like you to develop a random forest to classify the MNIST dataset. Here are the basic requirements:
This assignment is essentially a combination of parts of the "module3_2_multiclassV2.ipynb" workbook, as well as the "module4_2_random_forest.ipynb" workbook.
NOTE: This whole assignment can be done using the "short" dataset. You can (if you want) use the full dataset, but it is not necessary. It will take much longer to run if you do!
Use the structure below to craft your solution.
import plotly.express as px
import plotly.io as pio
pio.renderers.default='notebook'
%matplotlib inline
import pandas as pd
#
# Data location
data_location = '/fs/ess/PAS2038/PHYSICS5680_OSU/data'
#
# Dataset size selector: set to "short_" to read the reduced files,
# or "" for the full dataset (much slower to process).
short = ""
#short = "short_"
#
# Read in every digit file and combine them into a single DataFrame.
# Collect the per-digit frames in a list and concatenate once at the
# end: repeatedly concatenating inside the loop re-copies all earlier
# rows each iteration (quadratic in the total number of rows).
frames = []
for digit in range(10):
    print("Processing digit ",digit)
    fname = data_location + '/ch3/digit_' + short + str(digit) + '.csv'
    df = pd.read_csv(fname,header=None)
    df['digit'] = digit  # label column for this file's rows
    frames.append(df)
dfCombined = pd.concat(frames)
print("Length of sample: ",len(dfCombined))
# Number of pixel features per image (28 x 28 = 784)
num_features = 784
Processing digit 0 Processing digit 1 Processing digit 2 Processing digit 3 Processing digit 4 Processing digit 5 Processing digit 6 Processing digit 7 Processing digit 8 Processing digit 9 Length of sample: 70000
# Used to implement the multi-dimensional counter we need in the performance class
from collections import defaultdict
from functools import partial
from itertools import repeat
def nested_defaultdict(default_factory, depth=1):
    """Return a defaultdict nested `depth` levels deep.

    Each intermediate level auto-creates another defaultdict, and the
    innermost level uses `default_factory` (e.g. int for a counter).
    """
    factory = default_factory
    # Wrap the factory once per extra level of nesting.
    for _ in range(depth - 1):
        factory = partial(defaultdict, factory)
    return defaultdict(factory)
# Determine the performance
def multiPerformance(y,y_pred,y_score,debug=False):
    """Compute the multiclass confusion matrix and micro/macro accuracy.

    Parameters
    ----------
    y : sequence of true class labels
    y_pred : sequence of predicted class labels, same length as y
    y_score : unused; kept so callers that also carry per-class scores
        keep a stable interface
    debug : if True, print the confusion matrix

    Returns
    -------
    dict with keys 'confusionMatrix' (nested dict of counts, indexed
    [trueClass][predClass]), 'accuracyMicro', 'accuracyMacro'.
    """
    #
    # Build the confusion matrix. A two-level defaultdict counts
    # (true, pred) pairs without pre-declaring the classes.
    confusionMatrix = defaultdict(lambda: defaultdict(int))
    classes = set()
    totalTrue = defaultdict(int)
    totalPred = defaultdict(int)
    # zip avoids positional indexing, so y/y_pred may be any iterables
    for trueClass, predClass in zip(y, y_pred):
        classes.add(trueClass)
        totalTrue[trueClass] += 1
        totalPred[predClass] += 1
        confusionMatrix[trueClass][predClass] += 1
    if debug:
        for trueClass in classes:
            print("True: ",trueClass,end="")
            for predClass in classes:
                print("\t",confusionMatrix[trueClass][predClass],end="")
            print()
        print()
    #
    # Micro accuracy: sum the diagonal and divide by the total count.
    # Macro accuracy: average of the per-class recalls.
    accMicro = 0.0
    accMacro = 0.0
    for cl in classes:
        accMicro += confusionMatrix[cl][cl]
        accMacro += confusionMatrix[cl][cl]/totalTrue[cl]
    accMicro /= len(y)
    accMacro = accMacro / len(classes)
    results = {"confusionMatrix":confusionMatrix,"accuracyMicro":accMicro,"accuracyMacro":accMacro}
    return results
Get this from the module3_2_multiclassV2 notebook.
def runFitter(estimator,X_train,y_train,X_test,y_test,debug=False):
    """Fit `estimator` on the training set and evaluate on both sets.

    Returns a dict with the confusion matrices and the micro/macro
    accuracies for the train and test sets (via multiPerformance).
    """
    #
    # Now fit to our training set
    estimator.fit(X_train,y_train)
    #
    # Predict the classes and get the scores for our training set.
    # NOTE: some estimators have a decision_function method instead of
    # predict_proba. We keep the full (n_samples, n_classes) probability
    # matrix: slicing out column 1 (as in the binary case) would give
    # the probability of a single class only in this 10-class problem.
    y_train_pred = estimator.predict(X_train)
    y_train_score = estimator.predict_proba(X_train)
    #
    # Now predict the classes and get the score for our test set
    y_test_pred = estimator.predict(X_test)
    y_test_score = estimator.predict_proba(X_test)
    #
    # Now get the performance
    results_test = multiPerformance(y_test,y_test_pred,y_test_score,debug=False)
    results_train = multiPerformance(y_train,y_train_pred,y_train_score,debug=False)
    #
    # Decide what you want to return: for now, the confusion matrices
    # plus micro/macro accuracy for both test and train
    results = {
        'cf_test':results_test['confusionMatrix'],
        'cf_train':results_train['confusionMatrix'],
        'accuracyMicro_test':results_test['accuracyMicro'],
        'accuracyMacro_test':results_test['accuracyMacro'],
        'accuracyMicro_train':results_train['accuracyMicro'],
        'accuracyMacro_train':results_train['accuracyMacro'],
    }
    return results
Get this also from module3_2_multiclassV2.
As in that case, we will split the data into two datasets:
Remember that the features are the first 784 columns, and the labels are given by the "digit" column.
#
# Shuffle the combined sample, then hold out the final 10% of rows.
from sklearn.utils import shuffle
# Fixing the random state makes the shuffle (and everything downstream) reproducible
dfCombinedShuffle = shuffle(dfCombined,random_state=42)
train_length = int(0.9*len(dfCombinedShuffle))
# Features are the first 784 columns; the label is column 784 ("digit")
train_part = dfCombinedShuffle.iloc[:train_length]
holdout_part = dfCombinedShuffle.iloc[train_length:]
X = train_part.iloc[:, :784].to_numpy()
y = train_part.iloc[:, 784].values
X_holdout = holdout_part.iloc[:, :784].to_numpy()
y_holdout = holdout_part.iloc[:, 784].values
print(len(X),len(X_holdout))
63000 7000
Use a reasonable set of parameters for the forest estimator: for example, n_estimators=100, max_depth=5. Then perform one run of k-fold (with k=5) validation to see what sort of accuracy you get.
Print the average accuracy from the 5 kfolds for each digit.
An overall average macro accuracy of 86% on the test set is reasonable with these parameters.
from sklearn.ensemble import RandomForestClassifier
estimator = RandomForestClassifier(n_estimators=100, max_depth=5,random_state=42)
# Evaluate on the held-out 10%, not on the training data itself.
# (Passing the training set as the "test" set -- runFitter(estimator,X,y,X,y) --
# makes the test and train accuracies identical by construction, which is
# why the earlier run printed the same numbers for both.)
results = runFitter(estimator,X,y,X_holdout,y_holdout,debug=False)
print('A run with n_estimators=100 and max_depth = 5')
print('accuracyMicro for train set: ' + str(results['accuracyMicro_train']))
print('accuracyMacro for train set: ' + str(results['accuracyMacro_train']))
print('accuracyMicro for test set: ' + str(results['accuracyMicro_test']))
print('accuracyMacro for test set: ' + str(results['accuracyMacro_test']))
A run with n_estimators=100 and max_depth = 5 accuracyMicro for train set: 0.8591746031746031 accuracyMacro for train set: 0.8557300597127874 accuracyMicro for test set: 0.8591746031746031 accuracyMacro for test set: 0.8557300597127874
#
# 5-fold cross-validation of the forest, tracking overall and per-digit accuracy
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold
kfolds = 5
#skf = StratifiedKFold(n_splits=kfolds)
skf = KFold(n_splits=kfolds)
import numpy as np
# Get our estimator and predict
estimator = RandomForestClassifier(n_estimators=100, max_depth=5,random_state=42)
# Create some vars to keep track of everything (running sums over folds)
avg_accuracyMicro_test = 0.0
avg_accuracyMicro_train = 0.0
avg_accuracyMacro_test = 0.0
avg_accuracyMacro_train = 0.0
numSplits = 0.0
#
# Per-digit accuracy summed over folds; defaultdict(float) starts each
# digit at 0.0, so a plain += accumulation works for every fold.
accuracies_by_digit = defaultdict(float)
for train_index, test_index in skf.split(X, y):
    print("Training")
    X_train = X[train_index]
    y_train = y[train_index]
    X_test = X[test_index]
    y_test = y[test_index]
    #
    # Now fit to our training set and evaluate on this fold's test part
    results = runFitter(estimator,X_train,y_train,X_test,y_test)
    #
    avg_accuracyMicro_test += results['accuracyMicro_test']
    avg_accuracyMicro_train += results['accuracyMicro_train']
    avg_accuracyMacro_test += results['accuracyMacro_test']
    avg_accuracyMacro_train += results['accuracyMacro_train']
    lastCF_train = results['cf_train']
    lastCF_test = results['cf_test']
    numSplits += 1.0
    print(" Split ",numSplits,"; accuracyMicro test/train",results['accuracyMicro_test'],results['accuracyMicro_train'],"; accuracyMacro test/train",results['accuracyMacro_test'],results['accuracyMacro_train'])
    #
    # Accumulate this fold's per-digit accuracy: the diagonal entry of
    # the confusion matrix is the count classified correctly, and the
    # row sum is how many of that digit were in this fold's test set.
    for i in range(10):
        classified_correct = results['cf_test'][i][i]
        total_true_digits = sum(results['cf_test'][i].values())
        if total_true_digits:  # guard: digit absent from this fold's test set
            accuracies_by_digit[f'Digit {i}'] += classified_correct/total_true_digits
#
# Average the per-digit accuracies over the folds and report them
for i in range(10):
    accuracies_by_digit[f'Digit {i}'] /= float(kfolds)
    print(f'average accuracy for digit {i}: ',str(accuracies_by_digit[f'Digit {i}']))
avg_accuracyMicro_test /= numSplits
avg_accuracyMicro_train /= numSplits
avg_accuracyMacro_test /= numSplits
avg_accuracyMacro_train /= numSplits
# Now print
print("Average Micro Accuracy train/test ",round(avg_accuracyMicro_train,3),round(avg_accuracyMicro_test,3))
print("Average Macro Accuracy train/test ",round(avg_accuracyMacro_train,3),round(avg_accuracyMacro_test,3))
Training Split 1.0 ; accuracyMicro test/train 0.8584920634920635 0.8582936507936508 ; accuracyMacro test/train 0.8550026320185321 0.854553863049085 Training Split 2.0 ; accuracyMicro test/train 0.8523015873015873 0.8599801587301588 ; accuracyMacro test/train 0.8505649614230805 0.8563950580881963 Training Split 3.0 ; accuracyMicro test/train 0.8580952380952381 0.8596428571428572 ; accuracyMacro test/train 0.8546320520160592 0.8562087605759645 Training Split 4.0 ; accuracyMicro test/train 0.8576190476190476 0.8623412698412698 ; accuracyMacro test/train 0.8536619641033599 0.8592497795382869 Training Split 5.0 ; accuracyMicro test/train 0.8615873015873016 0.8608531746031746 ; accuracyMacro test/train 0.8583299439235998 0.8577590216308184 average accuracy for digit 0: 0.9610787712451507 average accuracy for digit 1: 0.9777301547145256 average accuracy for digit 2: 0.8400432383812877 average accuracy for digit 3: 0.8445075902863233 average accuracy for digit 4: 0.8232460361193633 average accuracy for digit 5: 0.6650244382165857 average accuracy for digit 6: 0.9227685295033415 average accuracy for digit 7: 0.8867836830553179 average accuracy for digit 8: 0.7850277312825888 average accuracy for digit 9: 0.8381729341647783 Average Micro Accuracy train/test 0.86 0.858 Average Macro Accuracy train/test 0.857 0.854
Here you want to loop from max_depth=2 to 22. You could do this in increments of 2 (4,6,8..,20,22) to save time.
Use n_estimators=50 to save even more time.
Look at the "Overfitting vs Underfitting" section of module4_2_random_forest.ipynb. However, in your dfError object, you will only need to save "1-avg_accuracyMacro" for test and train.
import numpy as np
# Scan model complexity: run k-fold CV at each max_depth, recording
# error = 1 - average accuracy for train and test.
# Rows are collected in a list and the DataFrame is built once at the
# end: DataFrame.append was removed in pandas 2.0, and the original
# per-iteration append is quadratic anyway. Column names use capital M
# ('trainError_Micro', ...) to match the plotting cell below -- the old
# code declared lowercase columns but appended capital-M keys, leaving
# the declared columns permanently empty.
error_rows = []
# Now loop
for max_depth in range(1,23):
    print("Training with max depth =",max_depth)
    # Create some vars to keep track of everything
    avg_accuracyMicro_test = 0.0
    avg_accuracyMicro_train = 0.0
    avg_accuracyMacro_test = 0.0
    avg_accuracyMacro_train = 0.0
    numSplits = 0.0
    estimator = RandomForestClassifier(n_estimators=100,random_state=42,max_depth=max_depth)
    for train_index, test_index in skf.split(X, y):
        X_train = X[train_index]
        y_train = y[train_index]
        X_test = X[test_index]
        y_test = y[test_index]
        #
        # Now fit to our training set and accumulate this fold's accuracies
        results = runFitter(estimator,X_train,y_train,X_test,y_test)
        avg_accuracyMicro_test += results['accuracyMicro_test']
        avg_accuracyMicro_train += results['accuracyMicro_train']
        avg_accuracyMacro_test += results['accuracyMacro_test']
        avg_accuracyMacro_train += results['accuracyMacro_train']
        numSplits += 1.0
    avg_accuracyMicro_test /= numSplits
    avg_accuracyMicro_train /= numSplits
    avg_accuracyMacro_test /= numSplits
    avg_accuracyMacro_train /= numSplits
    # Record this depth's errors (error = 1 - accuracy)
    error_rows.append({
        'max_depth':max_depth,
        'trainError_Micro':1.0-avg_accuracyMicro_train,'testError_Micro':1.0-avg_accuracyMicro_test,
        'trainError_Macro':1.0-avg_accuracyMacro_train,'testError_Macro':1.0-avg_accuracyMacro_test
    })
dfError = pd.DataFrame(error_rows)
Training with max depth = 1 Training with max depth = 2 Training with max depth = 3 Training with max depth = 4 Training with max depth = 5 Training with max depth = 6 Training with max depth = 7 Training with max depth = 8 Training with max depth = 9 Training with max depth = 10 Training with max depth = 11 Training with max depth = 12 Training with max depth = 13 Training with max depth = 14 Training with max depth = 15 Training with max depth = 16 Training with max depth = 17 Training with max depth = 18 Training with max depth = 19 Training with max depth = 20 Training with max depth = 21 Training with max depth = 22
This plot should have both train and test results on it.
import plotly.express as px
import plotly.io as pio
pio.renderers.default='notebook'
# Plot error vs model complexity (train and test curves on the same
# figure) for both the micro- and macro-accuracy flavors.
for flavor in ('Micro', 'Macro'):
    fig = px.line(dfError, x='max_depth',
                  y=[f'trainError_{flavor}', f'testError_{flavor}'],
                  title=f'Error ({flavor} Accuracy) vs Model Complexity')
    fig.show()
As noted above, we want to find the feature importance of the pixels. However, we want to do this with the final model. What is the final model? It is the model with the ideal max_depth that we determined from Task 6. Choose the max_depth where the test error is small and the model is simplest (where the test accuracy begins to plateau).
So you will want to re-train the model, using this max_depth, as well as the full data. Look at how this is done at the end of "module4_2_random_forest.ipynb".
Also, to plot the feature importance of the pixels, you will have to reshape the returned importances (which are of length 784) to an array of shape (28,28). Then you can plot them as a heatmap. For plotly express this is best done using px.imshow.
Also, are you sure that the orientation of the heatmap is correct? Plot a single digit (like X_train[0].reshape(28,28)) to check this.
from sklearn.utils import shuffle
dfCombinedShuffle = shuffle(dfCombined,random_state=42) # by setting the random state we will get reproducible results
# Final model: retrain on the FULL dataset with the chosen max_depth
X = dfCombinedShuffle.iloc[:,:num_features].to_numpy()
y = dfCombinedShuffle.iloc[:,num_features].values
from sklearn.ensemble import RandomForestClassifier
estimator = RandomForestClassifier(n_estimators=100, max_depth=8,random_state=42,oob_score=True)
# NOTE: train and test are the same full sample here by design -- the two
# accuracies will be identical; this fit exists for the feature importances
results = runFitter(estimator,X,y,X,y,debug=False)
#
# printout feature importance
print('A run with n_estimators=100 and max_depth = 8')  # label fixed: model above uses max_depth=8, not 7
print('accuracyMicro for train set: ' + str(results['accuracyMicro_train']))
print('accuracyMacro for train set: ' + str(results['accuracyMacro_train']))
print('accuracyMicro for test set: ' + str(results['accuracyMicro_test']))
print('accuracyMacro for test set: ' + str(results['accuracyMacro_test']))
importanceByName = {}
#print("unsorted importance")
for name,importance in zip(dfCombinedShuffle.columns[:num_features],estimator.feature_importances_):
    importanceByName[name] = importance
print("Sorted importance")
for name in sorted(importanceByName, key=importanceByName.get, reverse=True):
    print("Pixel,importance",name,round(importanceByName[name],3))
#
# Heatmap of per-pixel importance, as the task requires: reshape the
# 784 importances to the image shape (28, 28) and show as an image.
import numpy as np
fig = px.imshow(np.asarray(estimator.feature_importances_).reshape(28, 28),
                title='Random forest pixel importance')
fig.show()
A run with n_estimators=100 and max_depth = 7 accuracyMicro for train set: 0.9313285714285714 accuracyMacro for train set: 0.9305332198240922 accuracyMicro for test set: 0.9313285714285714 accuracyMacro for test set: 0.9305332198240922 Sorted importance Pixel,importance 350 0.014 Pixel,importance 378 0.012 Pixel,importance 409 0.011 Pixel,importance 461 0.011 Pixel,importance 433 0.01 Pixel,importance 406 0.01 Pixel,importance 568 0.009 Pixel,importance 437 0.009 Pixel,importance 375 0.009 Pixel,importance 155 0.009 Pixel,importance 318 0.009 Pixel,importance 465 0.009 Pixel,importance 542 0.008 Pixel,importance 514 0.008 Pixel,importance 429 0.008 Pixel,importance 489 0.008 Pixel,importance 405 0.008 Pixel,importance 460 0.008 Pixel,importance 569 0.007 Pixel,importance 377 0.007 Pixel,importance 462 0.007 Pixel,importance 434 0.007 Pixel,importance 238 0.007 Pixel,importance 515 0.007 Pixel,importance 656 0.007 Pixel,importance 290 0.007 Pixel,importance 381 0.007 Pixel,importance 543 0.006 Pixel,importance 345 0.006 Pixel,importance 541 0.006 Pixel,importance 322 0.006 Pixel,importance 407 0.006 Pixel,importance 403 0.006 Pixel,importance 401 0.006 Pixel,importance 597 0.006 Pixel,importance 153 0.006 Pixel,importance 347 0.006 Pixel,importance 376 0.006 Pixel,importance 435 0.006 Pixel,importance 182 0.006 Pixel,importance 346 0.006 Pixel,importance 156 0.006 Pixel,importance 402 0.006 Pixel,importance 487 0.006 Pixel,importance 240 0.006 Pixel,importance 596 0.006 Pixel,importance 154 0.005 Pixel,importance 210 0.005 Pixel,importance 488 0.005 Pixel,importance 319 0.005 Pixel,importance 323 0.005 Pixel,importance 655 0.005 Pixel,importance 239 0.005 Pixel,importance 152 0.005 Pixel,importance 626 0.005 Pixel,importance 386 0.005 Pixel,importance 540 0.005 Pixel,importance 374 0.005 Pixel,importance 456 0.005 Pixel,importance 567 0.005 Pixel,importance 430 0.005 Pixel,importance 291 0.005 Pixel,importance 657 0.005 Pixel,importance 183 0.005 Pixel,importance 
262 0.005 Pixel,importance 400 0.005 Pixel,importance 516 0.005 Pixel,importance 100 0.005 Pixel,importance 431 0.005 Pixel,importance 486 0.005 Pixel,importance 459 0.005 Pixel,importance 513 0.005 Pixel,importance 457 0.005 Pixel,importance 373 0.005 Pixel,importance 432 0.005 Pixel,importance 490 0.005 Pixel,importance 428 0.004 Pixel,importance 379 0.004 Pixel,importance 570 0.004 Pixel,importance 625 0.004 Pixel,importance 658 0.004 Pixel,importance 404 0.004 Pixel,importance 351 0.004 Pixel,importance 299 0.004 Pixel,importance 211 0.004 Pixel,importance 212 0.004 Pixel,importance 184 0.004 Pixel,importance 358 0.004 Pixel,importance 326 0.004 Pixel,importance 298 0.004 Pixel,importance 484 0.004 Pixel,importance 382 0.004 Pixel,importance 271 0.004 Pixel,importance 237 0.004 Pixel,importance 317 0.004 Pixel,importance 517 0.004 Pixel,importance 272 0.004 Pixel,importance 353 0.004 Pixel,importance 263 0.004 Pixel,importance 485 0.004 Pixel,importance 296 0.004 Pixel,importance 539 0.004 Pixel,importance 458 0.004 Pixel,importance 236 0.004 Pixel,importance 270 0.004 Pixel,importance 550 0.003 Pixel,importance 354 0.003 Pixel,importance 551 0.003 Pixel,importance 348 0.003 Pixel,importance 372 0.003 Pixel,importance 595 0.003 Pixel,importance 320 0.003 Pixel,importance 380 0.003 Pixel,importance 268 0.003 Pixel,importance 427 0.003 Pixel,importance 522 0.003 Pixel,importance 464 0.003 Pixel,importance 127 0.003 Pixel,importance 463 0.003 Pixel,importance 654 0.003 Pixel,importance 297 0.003 Pixel,importance 266 0.003 Pixel,importance 151 0.003 Pixel,importance 267 0.003 Pixel,importance 325 0.003 Pixel,importance 181 0.003 Pixel,importance 241 0.003 Pixel,importance 494 0.003 Pixel,importance 352 0.003 Pixel,importance 295 0.003 Pixel,importance 126 0.003 Pixel,importance 207 0.003 Pixel,importance 289 0.003 Pixel,importance 178 0.003 Pixel,importance 269 0.003 Pixel,importance 571 0.003 Pixel,importance 491 0.003 Pixel,importance 414 0.003 Pixel,importance 
264 0.003 Pixel,importance 157 0.003 Pixel,importance 436 0.003 Pixel,importance 179 0.003 Pixel,importance 242 0.003 Pixel,importance 101 0.003 Pixel,importance 327 0.003 Pixel,importance 408 0.003 Pixel,importance 215 0.003 Pixel,importance 659 0.003 Pixel,importance 330 0.003 Pixel,importance 544 0.003 Pixel,importance 357 0.003 Pixel,importance 206 0.003 Pixel,importance 324 0.003 Pixel,importance 214 0.003 Pixel,importance 410 0.002 Pixel,importance 573 0.002 Pixel,importance 243 0.002 Pixel,importance 235 0.002 Pixel,importance 466 0.002 Pixel,importance 273 0.002 Pixel,importance 208 0.002 Pixel,importance 99 0.002 Pixel,importance 577 0.002 Pixel,importance 483 0.002 Pixel,importance 180 0.002 Pixel,importance 244 0.002 Pixel,importance 623 0.002 Pixel,importance 209 0.002 Pixel,importance 300 0.002 Pixel,importance 349 0.002 Pixel,importance 98 0.002 Pixel,importance 316 0.002 Pixel,importance 413 0.002 Pixel,importance 213 0.002 Pixel,importance 177 0.002 Pixel,importance 511 0.002 Pixel,importance 579 0.002 Pixel,importance 371 0.002 Pixel,importance 321 0.002 Pixel,importance 627 0.002 Pixel,importance 205 0.002 Pixel,importance 512 0.002 Pixel,importance 216 0.002 Pixel,importance 292 0.002 Pixel,importance 150 0.002 Pixel,importance 455 0.002 Pixel,importance 234 0.002 Pixel,importance 438 0.002 Pixel,importance 128 0.002 Pixel,importance 294 0.002 Pixel,importance 343 0.002 Pixel,importance 572 0.002 Pixel,importance 624 0.002 Pixel,importance 545 0.002 Pixel,importance 521 0.002 Pixel,importance 261 0.002 Pixel,importance 566 0.002 Pixel,importance 344 0.002 Pixel,importance 454 0.002 Pixel,importance 412 0.002 Pixel,importance 124 0.002 Pixel,importance 497 0.002 Pixel,importance 265 0.002 Pixel,importance 185 0.002 Pixel,importance 293 0.002 Pixel,importance 598 0.002 Pixel,importance 328 0.002 Pixel,importance 552 0.002 Pixel,importance 653 0.002 Pixel,importance 355 0.002 Pixel,importance 599 0.002 Pixel,importance 102 0.002 Pixel,importance 158 
0.002 Pixel,importance 189 0.002 Pixel,importance 123 0.001 Pixel,importance 630 0.001 Pixel,importance 496 0.001 Pixel,importance 518 0.001 Pixel,importance 441 0.001 Pixel,importance 217 0.001 Pixel,importance 329 0.001 Pixel,importance 493 0.001 Pixel,importance 399 0.001 Pixel,importance 411 0.001 Pixel,importance 245 0.001 Pixel,importance 660 0.001 Pixel,importance 384 0.001 Pixel,importance 301 0.001 Pixel,importance 712 0.001 Pixel,importance 125 0.001 Pixel,importance 383 0.001 Pixel,importance 186 0.001 Pixel,importance 685 0.001 Pixel,importance 510 0.001 Pixel,importance 524 0.001 Pixel,importance 385 0.001 Pixel,importance 686 0.001 Pixel,importance 439 0.001 Pixel,importance 187 0.001 Pixel,importance 426 0.001 Pixel,importance 468 0.001 Pixel,importance 218 0.001 Pixel,importance 576 0.001 Pixel,importance 652 0.001 Pixel,importance 575 0.001 Pixel,importance 684 0.001 Pixel,importance 492 0.001 Pixel,importance 97 0.001 Pixel,importance 538 0.001 Pixel,importance 574 0.001 Pixel,importance 628 0.001 Pixel,importance 581 0.001 Pixel,importance 176 0.001 Pixel,importance 578 0.001 Pixel,importance 683 0.001 Pixel,importance 233 0.001 Pixel,importance 467 0.001 Pixel,importance 288 0.001 Pixel,importance 580 0.001 Pixel,importance 549 0.001 Pixel,importance 160 0.001 Pixel,importance 523 0.001 Pixel,importance 546 0.001 Pixel,importance 356 0.001 Pixel,importance 482 0.001 Pixel,importance 440 0.001 Pixel,importance 259 0.001 Pixel,importance 398 0.001 Pixel,importance 315 0.001 Pixel,importance 622 0.001 Pixel,importance 415 0.001 Pixel,importance 606 0.001 Pixel,importance 651 0.001 Pixel,importance 681 0.001 Pixel,importance 161 0.001 Pixel,importance 387 0.001 Pixel,importance 632 0.001 Pixel,importance 519 0.001 Pixel,importance 274 0.001 Pixel,importance 159 0.001 Pixel,importance 96 0.001 Pixel,importance 601 0.001 Pixel,importance 605 0.001 Pixel,importance 191 0.001 Pixel,importance 495 0.001 Pixel,importance 682 0.001 Pixel,importance 631 
0.001 Pixel,importance 219 0.001 Pixel,importance 604 0.001 Pixel,importance 582 0.001 Pixel,importance 188 0.001 Pixel,importance 525 0.001 Pixel,importance 661 0.001 Pixel,importance 302 0.001 Pixel,importance 553 0.001 Pixel,importance 247 0.001 Pixel,importance 469 0.001 Pixel,importance 287 0.001 Pixel,importance 190 0.001 Pixel,importance 370 0.001 Pixel,importance 607 0.001 Pixel,importance 687 0.001 Pixel,importance 600 0.001 Pixel,importance 526 0.001 Pixel,importance 130 0.001 Pixel,importance 129 0.001 Pixel,importance 220 0.001 Pixel,importance 629 0.001 Pixel,importance 260 0.001 Pixel,importance 95 0.001 Pixel,importance 122 0.001 Pixel,importance 594 0.001 Pixel,importance 231 0.001 Pixel,importance 248 0.001 Pixel,importance 275 0.001 Pixel,importance 232 0.001 Pixel,importance 246 0.001 Pixel,importance 131 0.001 Pixel,importance 103 0.001 Pixel,importance 69 0.001 Pixel,importance 481 0.001 Pixel,importance 359 0.0 Pixel,importance 554 0.0 Pixel,importance 498 0.0 Pixel,importance 509 0.0 Pixel,importance 342 0.0 Pixel,importance 149 0.0 Pixel,importance 331 0.0 Pixel,importance 565 0.0 Pixel,importance 175 0.0 Pixel,importance 608 0.0 Pixel,importance 711 0.0 Pixel,importance 204 0.0 Pixel,importance 148 0.0 Pixel,importance 442 0.0 Pixel,importance 94 0.0 Pixel,importance 547 0.0 Pixel,importance 603 0.0 Pixel,importance 609 0.0 Pixel,importance 520 0.0 Pixel,importance 593 0.0 Pixel,importance 453 0.0 Pixel,importance 314 0.0 Pixel,importance 397 0.0 Pixel,importance 708 0.0 Pixel,importance 548 0.0 Pixel,importance 602 0.0 Pixel,importance 634 0.0 Pixel,importance 713 0.0 Pixel,importance 537 0.0 Pixel,importance 162 0.0 Pixel,importance 71 0.0 Pixel,importance 716 0.0 Pixel,importance 717 0.0 Pixel,importance 740 0.0 Pixel,importance 633 0.0 Pixel,importance 230 0.0 Pixel,importance 714 0.0 Pixel,importance 470 0.0 Pixel,importance 635 0.0 Pixel,importance 709 0.0 Pixel,importance 443 0.0 Pixel,importance 388 0.0 Pixel,importance 555 0.0 
Pixel,importance 192 0.0 Pixel,importance 527 0.0 Pixel,importance 718 0.0 Pixel,importance 662 0.0 Pixel,importance 680 0.0 Pixel,importance 286 0.0 Pixel,importance 528 0.0 Pixel,importance 163 0.0 Pixel,importance 744 0.0 Pixel,importance 341 0.0 Pixel,importance 556 0.0 Pixel,importance 104 0.0 Pixel,importance 276 0.0 Pixel,importance 132 0.0 Pixel,importance 258 0.0 Pixel,importance 147 0.0 Pixel,importance 202 0.0 Pixel,importance 710 0.0 Pixel,importance 742 0.0 Pixel,importance 564 0.0 Pixel,importance 715 0.0 Pixel,importance 303 0.0 Pixel,importance 203 0.0 Pixel,importance 425 0.0 Pixel,importance 121 0.0 Pixel,importance 592 0.0 Pixel,importance 637 0.0 Pixel,importance 471 0.0 Pixel,importance 743 0.0 Pixel,importance 369 0.0 Pixel,importance 636 0.0 Pixel,importance 610 0.0 Pixel,importance 284 0.0 Pixel,importance 691 0.0 Pixel,importance 68 0.0 Pixel,importance 304 0.0 Pixel,importance 583 0.0 Pixel,importance 285 0.0 Pixel,importance 690 0.0 Pixel,importance 706 0.0 Pixel,importance 739 0.0 Pixel,importance 228 0.0 Pixel,importance 747 0.0 Pixel,importance 688 0.0 Pixel,importance 332 0.0 Pixel,importance 221 0.0 Pixel,importance 499 0.0 Pixel,importance 679 0.0 Pixel,importance 737 0.0 Pixel,importance 193 0.0 Pixel,importance 650 0.0 Pixel,importance 72 0.0 Pixel,importance 201 0.0 Pixel,importance 611 0.0 Pixel,importance 70 0.0 Pixel,importance 164 0.0 Pixel,importance 74 0.0 Pixel,importance 621 0.0 Pixel,importance 689 0.0 Pixel,importance 313 0.0 Pixel,importance 508 0.0 Pixel,importance 105 0.0 Pixel,importance 452 0.0 Pixel,importance 249 0.0 Pixel,importance 120 0.0 Pixel,importance 536 0.0 Pixel,importance 93 0.0 Pixel,importance 256 0.0 Pixel,importance 133 0.0 Pixel,importance 134 0.0 Pixel,importance 277 0.0 Pixel,importance 500 0.0 Pixel,importance 719 0.0 Pixel,importance 67 0.0 Pixel,importance 745 0.0 Pixel,importance 677 0.0 Pixel,importance 66 0.0 Pixel,importance 707 0.0 Pixel,importance 584 0.0 Pixel,importance 665 0.0 
Pixel,importance 444 0.0 Pixel,importance 360 0.0 Pixel,importance 73 0.0 Pixel,importance 741 0.0 Pixel,importance 472 0.0 Pixel,importance 735 0.0 Pixel,importance 312 0.0 Pixel,importance 76 0.0 Pixel,importance 92 0.0 Pixel,importance 229 0.0 Pixel,importance 678 0.0 Pixel,importance 648 0.0 Pixel,importance 663 0.0 Pixel,importance 283 0.0 Pixel,importance 738 0.0 Pixel,importance 693 0.0 Pixel,importance 639 0.0 Pixel,importance 620 0.0 Pixel,importance 638 0.0 Pixel,importance 135 0.0 Pixel,importance 692 0.0 Pixel,importance 664 0.0 Pixel,importance 501 0.0 Pixel,importance 396 0.0 Pixel,importance 585 0.0 Pixel,importance 416 0.0 Pixel,importance 563 0.0 Pixel,importance 720 0.0 Pixel,importance 721 0.0 Pixel,importance 75 0.0 Pixel,importance 146 0.0 Pixel,importance 479 0.0 Pixel,importance 771 0.0 Pixel,importance 173 0.0 Pixel,importance 257 0.0 Pixel,importance 174 0.0 Pixel,importance 424 0.0 Pixel,importance 666 0.0 Pixel,importance 591 0.0 Pixel,importance 507 0.0 Pixel,importance 535 0.0 Pixel,importance 119 0.0 Pixel,importance 106 0.0 Pixel,importance 107 0.0 Pixel,importance 746 0.0 Pixel,importance 619 0.0 Pixel,importance 480 0.0 Pixel,importance 340 0.0 Pixel,importance 586 0.0 Pixel,importance 529 0.0 Pixel,importance 736 0.0 Pixel,importance 91 0.0 Pixel,importance 649 0.0 Pixel,importance 557 0.0 Pixel,importance 306 0.0 Pixel,importance 368 0.0 Pixel,importance 450 0.0 Pixel,importance 136 0.0 Pixel,importance 694 0.0 Pixel,importance 640 0.0 Pixel,importance 451 0.0 Pixel,importance 250 0.0 Pixel,importance 478 0.0 Pixel,importance 222 0.0 Pixel,importance 705 0.0 Pixel,importance 77 0.0 Pixel,importance 667 0.0 Pixel,importance 446 0.0 Pixel,importance 473 0.0 Pixel,importance 676 0.0 Pixel,importance 417 0.0 Pixel,importance 41 0.0 Pixel,importance 227 0.0 Pixel,importance 669 0.0 Pixel,importance 612 0.0 Pixel,importance 704 0.0 Pixel,importance 90 0.0 Pixel,importance 172 0.0 Pixel,importance 199 0.0 Pixel,importance 117 0.0 
Pixel,importance 0 0.0 Pixel,importance 1 0.0 Pixel,importance 2 0.0 Pixel,importance 3 0.0 Pixel,importance 4 0.0 Pixel,importance 5 0.0 Pixel,importance 6 0.0 Pixel,importance 7 0.0 Pixel,importance 8 0.0 Pixel,importance 9 0.0 Pixel,importance 10 0.0 Pixel,importance 11 0.0 Pixel,importance 12 0.0 Pixel,importance 13 0.0 Pixel,importance 14 0.0 Pixel,importance 15 0.0 Pixel,importance 16 0.0 Pixel,importance 17 0.0 Pixel,importance 18 0.0 Pixel,importance 19 0.0 Pixel,importance 20 0.0 Pixel,importance 21 0.0 Pixel,importance 22 0.0 Pixel,importance 23 0.0 Pixel,importance 24 0.0 Pixel,importance 25 0.0 Pixel,importance 26 0.0 Pixel,importance 27 0.0 Pixel,importance 28 0.0 Pixel,importance 29 0.0 Pixel,importance 30 0.0 Pixel,importance 31 0.0 Pixel,importance 32 0.0 Pixel,importance 33 0.0 Pixel,importance 34 0.0 Pixel,importance 35 0.0 Pixel,importance 36 0.0 Pixel,importance 37 0.0 Pixel,importance 38 0.0 Pixel,importance 39 0.0 Pixel,importance 40 0.0 Pixel,importance 42 0.0 Pixel,importance 43 0.0 Pixel,importance 44 0.0 Pixel,importance 45 0.0 Pixel,importance 46 0.0 Pixel,importance 47 0.0 Pixel,importance 48 0.0 Pixel,importance 49 0.0 Pixel,importance 50 0.0 Pixel,importance 51 0.0 Pixel,importance 52 0.0 Pixel,importance 53 0.0 Pixel,importance 54 0.0 Pixel,importance 55 0.0 Pixel,importance 56 0.0 Pixel,importance 57 0.0 Pixel,importance 58 0.0 Pixel,importance 59 0.0 Pixel,importance 60 0.0 Pixel,importance 61 0.0 Pixel,importance 62 0.0 Pixel,importance 63 0.0 Pixel,importance 64 0.0 Pixel,importance 65 0.0 Pixel,importance 78 0.0 Pixel,importance 79 0.0 Pixel,importance 80 0.0 Pixel,importance 81 0.0 Pixel,importance 82 0.0 Pixel,importance 83 0.0 Pixel,importance 84 0.0 Pixel,importance 85 0.0 Pixel,importance 86 0.0 Pixel,importance 87 0.0 Pixel,importance 88 0.0 Pixel,importance 89 0.0 Pixel,importance 108 0.0 Pixel,importance 109 0.0 Pixel,importance 110 0.0 Pixel,importance 111 0.0 Pixel,importance 112 0.0 Pixel,importance 113 0.0 
Pixel,importance 114 0.0 Pixel,importance 115 0.0 Pixel,importance 116 0.0 Pixel,importance 118 0.0 Pixel,importance 137 0.0 Pixel,importance 138 0.0 Pixel,importance 139 0.0 Pixel,importance 140 0.0 Pixel,importance 141 0.0 Pixel,importance 142 0.0 Pixel,importance 143 0.0 Pixel,importance 144 0.0 Pixel,importance 145 0.0 Pixel,importance 165 0.0 Pixel,importance 166 0.0 Pixel,importance 167 0.0 Pixel,importance 168 0.0 Pixel,importance 169 0.0 Pixel,importance 170 0.0 Pixel,importance 171 0.0 Pixel,importance 194 0.0 Pixel,importance 195 0.0 Pixel,importance 196 0.0 Pixel,importance 197 0.0 Pixel,importance 198 0.0 Pixel,importance 200 0.0 Pixel,importance 223 0.0 Pixel,importance 224 0.0 Pixel,importance 225 0.0 Pixel,importance 226 0.0 Pixel,importance 251 0.0 Pixel,importance 252 0.0 Pixel,importance 253 0.0 Pixel,importance 254 0.0 Pixel,importance 255 0.0 Pixel,importance 278 0.0 Pixel,importance 279 0.0 Pixel,importance 280 0.0 Pixel,importance 281 0.0 Pixel,importance 282 0.0 Pixel,importance 305 0.0 Pixel,importance 307 0.0 Pixel,importance 308 0.0 Pixel,importance 309 0.0 Pixel,importance 310 0.0 Pixel,importance 311 0.0 Pixel,importance 333 0.0 Pixel,importance 334 0.0 Pixel,importance 335 0.0 Pixel,importance 336 0.0 Pixel,importance 337 0.0 Pixel,importance 338 0.0 Pixel,importance 339 0.0 Pixel,importance 361 0.0 Pixel,importance 362 0.0 Pixel,importance 363 0.0 Pixel,importance 364 0.0 Pixel,importance 365 0.0 Pixel,importance 366 0.0 Pixel,importance 367 0.0 Pixel,importance 389 0.0 Pixel,importance 390 0.0 Pixel,importance 391 0.0 Pixel,importance 392 0.0 Pixel,importance 393 0.0 Pixel,importance 394 0.0 Pixel,importance 395 0.0 Pixel,importance 418 0.0 Pixel,importance 419 0.0 Pixel,importance 420 0.0 Pixel,importance 421 0.0 Pixel,importance 422 0.0 Pixel,importance 423 0.0 Pixel,importance 445 0.0 Pixel,importance 447 0.0 Pixel,importance 448 0.0 Pixel,importance 449 0.0 Pixel,importance 474 0.0 Pixel,importance 475 0.0 Pixel,importance 476 0.0 
Pixel,importance 477 0.0 Pixel,importance 502 0.0 Pixel,importance 503 0.0 Pixel,importance 504 0.0 Pixel,importance 505 0.0 Pixel,importance 506 0.0 Pixel,importance 530 0.0 Pixel,importance 531 0.0 Pixel,importance 532 0.0 Pixel,importance 533 0.0 Pixel,importance 534 0.0 Pixel,importance 558 0.0 Pixel,importance 559 0.0 Pixel,importance 560 0.0 Pixel,importance 561 0.0 Pixel,importance 562 0.0 Pixel,importance 587 0.0 Pixel,importance 588 0.0 Pixel,importance 589 0.0 Pixel,importance 590 0.0 Pixel,importance 613 0.0 Pixel,importance 614 0.0 Pixel,importance 615 0.0 Pixel,importance 616 0.0 Pixel,importance 617 0.0 Pixel,importance 618 0.0 Pixel,importance 641 0.0 Pixel,importance 642 0.0 Pixel,importance 643 0.0 Pixel,importance 644 0.0 Pixel,importance 645 0.0 Pixel,importance 646 0.0 Pixel,importance 647 0.0 Pixel,importance 668 0.0 Pixel,importance 670 0.0 Pixel,importance 671 0.0 Pixel,importance 672 0.0 Pixel,importance 673 0.0 Pixel,importance 674 0.0 Pixel,importance 675 0.0 Pixel,importance 695 0.0 Pixel,importance 696 0.0 Pixel,importance 697 0.0 Pixel,importance 698 0.0 Pixel,importance 699 0.0 Pixel,importance 700 0.0 Pixel,importance 701 0.0 Pixel,importance 702 0.0 Pixel,importance 703 0.0 Pixel,importance 722 0.0 Pixel,importance 723 0.0 Pixel,importance 724 0.0 Pixel,importance 725 0.0 Pixel,importance 726 0.0 Pixel,importance 727 0.0 Pixel,importance 728 0.0 Pixel,importance 729 0.0 Pixel,importance 730 0.0 Pixel,importance 731 0.0 Pixel,importance 732 0.0 Pixel,importance 733 0.0 Pixel,importance 734 0.0 Pixel,importance 748 0.0 Pixel,importance 749 0.0 Pixel,importance 750 0.0 Pixel,importance 751 0.0 Pixel,importance 752 0.0 Pixel,importance 753 0.0 Pixel,importance 754 0.0 Pixel,importance 755 0.0 Pixel,importance 756 0.0 Pixel,importance 757 0.0 Pixel,importance 758 0.0 Pixel,importance 759 0.0 Pixel,importance 760 0.0 Pixel,importance 761 0.0 Pixel,importance 762 0.0 Pixel,importance 763 0.0 Pixel,importance 764 0.0 Pixel,importance 765 0.0 
Pixel,importance 766 0.0 Pixel,importance 767 0.0 Pixel,importance 768 0.0 Pixel,importance 769 0.0 Pixel,importance 770 0.0 Pixel,importance 772 0.0 Pixel,importance 773 0.0 Pixel,importance 774 0.0 Pixel,importance 775 0.0 Pixel,importance 776 0.0 Pixel,importance 777 0.0 Pixel,importance 778 0.0 Pixel,importance 779 0.0 Pixel,importance 780 0.0 Pixel,importance 781 0.0 Pixel,importance 782 0.0 Pixel,importance 783 0.0
# Reshape the feature importance to be 28x28 and plot using px.imshow.
# importanceByName maps pixel index -> importance; dicts preserve insertion
# order, so the values come out in pixel order 0..783.
importance = np.array(list(importanceByName.values()))
px.imshow(importance.reshape(28, 28))
#
# Plot a single digit (a single row of X) using px.imshow
px.imshow(X[0].reshape(28, 28))
The shape of the feature importance - as seen from the imshow heatmap - might make sense. But is there a way to check this somehow? We might expect that a pixel's importance is related to how often that pixel is nonzero across all of the digit images. Let's see if this is the case.
Find a way to calculate the average occupancy of every pixel, across all of the images. Then:
np.where(X[0] != 0) #tuple of arrays: the indices of every nonzero pixel in the first image
(array([ 73, 74, 100, 101, 102, 127, 128, 129, 130, 155, 156, 157, 182,
183, 184, 210, 211, 212, 237, 238, 239, 265, 266, 267, 293, 294,
320, 321, 322, 325, 326, 327, 328, 347, 348, 349, 351, 352, 353,
354, 355, 356, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384,
402, 403, 404, 405, 406, 407, 408, 410, 411, 412, 430, 431, 432,
433, 434, 435, 438, 439, 440, 458, 459, 460, 461, 462, 465, 466,
467, 486, 487, 488, 489, 492, 493, 494, 495, 513, 514, 515, 516,
519, 520, 521, 541, 542, 543, 544, 545, 546, 547, 548, 569, 570,
571, 572, 573, 574, 575, 598, 599, 600, 601, 602]),)
np.where(X[0] != 0)[0][0] #index of the FIRST nonzero pixel in the image (not all of them)
73
# Compute the average occupancy of every pixel across all of the images:
# the fraction of images in which each pixel is nonzero.
#
# Vectorized replacement for the original O(images * pixels) Python loop,
# which appended every nonzero pixel index of every image into one big list
# and then tallied it with a Counter. np.count_nonzero(X, axis=0) produces
# the identical per-pixel counts in a single C-level pass.
from collections import Counter  # kept from the original cell in case later cells use it
number_of_images = X.shape[0]
# pixel_here[j] = number of images in which pixel j is nonzero (length num_features)
pixel_here = np.count_nonzero(X, axis=0)
average_occupancy = pixel_here / number_of_images
px.imshow(average_occupancy.reshape(28, 28))
px.imshow(importance.reshape(28, 28))
import matplotlib.pyplot as plt

# Scatter plot: each point is one pixel, comparing how important the random
# forest found it against how often it is occupied across the images.
fig, ax = plt.subplots()
ax.scatter(importance, average_occupancy, marker='.')
ax.set_xlabel('Importance Value')
ax.set_ylabel('Average Occupancy')
ax.set_title('Importance of Pixel vs Average Occupancy of Pixel')
plt.show()